{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "97e6eb98",
   "metadata": {},
   "source": [
    "## 13.7. 单发多框检测（SSD）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb9de0b8",
   "metadata": {},
   "source": [
    "### 13.7.1. 模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6eea46f9",
   "metadata": {},
   "source": [
    "#### 13.7.1.1. 类别预测层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e6901d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "from d2l import mindspore as d2l\n",
    "\n",
    "\n",
    "def cls_predictor(num_inputs, num_anchors, num_classes):\n",
    "    return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),\n",
    "                     kernel_size=3, padding=1, pad_mode='pad')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2814027",
   "metadata": {},
   "source": [
    "#### 13.7.1.2. 边界框预测层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ad634215",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bbox_predictor(num_inputs, num_anchors):\n",
    "    return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1, pad_mode='pad')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a615a1fd",
   "metadata": {},
   "source": [
    "#### 13.7.1.3. 连结多尺度的预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "78859fe3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((2, 55, 20, 20), (2, 33, 10, 10))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def forward(x, block):\n",
    "    return block(x)\n",
    "\n",
    "Y1 = forward(d2l.zeros((2, 8, 20, 20)), cls_predictor(8, 5, 10))\n",
    "Y2 = forward(d2l.zeros((2, 16, 10, 10)), cls_predictor(16, 3, 10))\n",
    "Y1.shape, Y2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38e5e4ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_pred(pred):\n",
    "    return d2l.flatten(pred.permute(0, 2, 3, 1)) # flatten不改变0轴的size\n",
    "\n",
    "def concat_preds(preds):\n",
    "    return d2l.concat([flatten_pred(p) for p in preds], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c4dae410",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 25300)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concat_preds([Y1, Y2]).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "496894eb",
   "metadata": {},
   "source": [
    "#### 13.7.1.4. 高和宽减半块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e6285680",
   "metadata": {},
   "outputs": [],
   "source": [
    "def down_sample_blk(in_channels, out_channels):\n",
    "    blk = []\n",
    "    for _ in range(2):\n",
    "        blk.append(nn.Conv2d(in_channels, out_channels,\n",
    "                             kernel_size=3, padding=1, pad_mode='pad'))\n",
    "        blk.append(nn.BatchNorm2d(out_channels))\n",
    "        blk.append(nn.ReLU())\n",
    "        in_channels = out_channels\n",
    "    blk.append(nn.MaxPool2d(2, 2))\n",
    "    return nn.SequentialCell(*blk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "28d0fa9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 10, 10, 10)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forward(d2l.zeros((2, 3, 20, 20)), down_sample_blk(3, 10)).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66c49ee1",
   "metadata": {},
   "source": [
    "#### 13.7.1.5. 基本网络块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "62b33e23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 64, 32, 32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def base_net():\n",
    "    blk = []\n",
    "    num_filters = [3, 16, 32, 64]\n",
    "    for i in range(len(num_filters) - 1):\n",
    "        blk.append(down_sample_blk(num_filters[i], num_filters[i+1]))\n",
    "    return nn.SequentialCell(*blk)\n",
    "\n",
    "forward(d2l.zeros((2, 3, 256, 256)), base_net()).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad80f94b",
   "metadata": {},
   "source": [
    "#### 13.7.1.6. 完整的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6849b62b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_blk(i):\n",
    "    if i == 0:\n",
    "        blk = base_net()\n",
    "    elif i == 1:\n",
    "        blk = down_sample_blk(64, 128)\n",
    "    elif i == 4:\n",
    "        blk = nn.AdaptiveMaxPool2d((1,1))\n",
    "    else:\n",
    "        blk = down_sample_blk(128, 128)\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bc2e287d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):\n",
    "    Y = blk(X)\n",
    "    anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)   ##### 需要第四章\n",
    "    cls_preds = cls_predictor(Y)\n",
    "    bbox_preds = bbox_predictor(Y)\n",
    "    return (Y, anchors, cls_preds, bbox_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7ebc15c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],\n",
    "         [0.88, 0.961]]\n",
    "ratios = [[1, 2, 0.5]] * 5\n",
    "num_anchors = len(sizes[0]) + len(ratios[0]) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27b39788",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TinySSD(nn.Cell):\n",
    "    def __init__(self, num_classes, **kwargs):\n",
    "        super(TinySSD, self).__init__(**kwargs)\n",
    "        self.num_classes = num_classes\n",
    "        idx_to_in_channels = [64, 128, 128, 128, 128]\n",
    "        for i in range(5):\n",
    "            # 即赋值语句self.blk_i=get_blk(i)\n",
    "            setattr(self, f'blk_{i}', get_blk(i))\n",
    "            setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],\n",
    "                                                    num_anchors, num_classes))\n",
    "            setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],\n",
    "                                                      num_anchors))\n",
    "\n",
    "    def construct(self, X):\n",
    "        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5\n",
    "        for i in range(5):\n",
    "            # getattr(self,'blk_%d'%i)即访问self.blk_i\n",
    "            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(\n",
    "                X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],\n",
    "                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))\n",
    "        anchors = d2l.concat(anchors, axis=1)\n",
    "        cls_preds = concat_preds(cls_preds)\n",
    "        cls_preds = cls_preds.reshape(\n",
    "            cls_preds.shape[0], -1, self.num_classes + 1)\n",
    "        bbox_preds = concat_preds(bbox_preds)\n",
    "        return anchors, cls_preds, bbox_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f6b40409",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output anchors: (1, 5444, 4)\n",
      "output class preds: (32, 5444, 2)\n",
      "output bbox preds: (32, 21776)\n"
     ]
    }
   ],
   "source": [
    "net = TinySSD(num_classes=1)\n",
    "X = d2l.zeros((32, 3, 256, 256))\n",
    "anchors, cls_preds, bbox_preds = net(X)\n",
    "\n",
    "print('output anchors:', anchors.shape)\n",
    "print('output class preds:', cls_preds.shape)\n",
    "print('output bbox preds:', bbox_preds.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38215c12",
   "metadata": {},
   "source": [
    "### 13.7.2. 训练模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e86cd5e",
   "metadata": {},
   "source": [
    "#### 13.7.2.1. 读取数据集和初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "21855eb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 1000 training examples\n",
      "read 100 validation examples\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "train_iter, _ = d2l.load_data_bananas(batch_size) # "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5948d3c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = TinySSD(num_classes=1)\n",
    "trainer = nn.SGD(net.trainable_params(), learning_rate=0.2, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26773b4e",
   "metadata": {},
   "source": [
    "#### 13.7.2.2. 定义损失函数和评价函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "19fff2c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_loss = nn.CrossEntropyLoss(reduction='none')\n",
    "bbox_loss = nn.L1Loss(reduction='none')\n",
    "\n",
    "def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):\n",
    "    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]\n",
    "    cls = cls_loss(cls_preds.reshape(-1, num_classes),\n",
    "                   cls_labels.reshape(-1)).reshape(batch_size, -1).mean(axis=1)\n",
    "    bbox = bbox_loss(bbox_preds * bbox_masks,\n",
    "                     bbox_labels * bbox_masks).mean(axis=1)\n",
    "    return cls + bbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a684727f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cls_eval(cls_preds, cls_labels):\n",
    "    # 由于类别预测结果放在最后一维，argmax需要指定最后一维。\n",
    "    return float((cls_preds.argmax(axis=-1).type(\n",
    "        cls_labels.dtype) == cls_labels).sum())\n",
    "\n",
    "def bbox_eval(bbox_preds, bbox_labels, bbox_masks):\n",
    "#     return float((d2l.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())\n",
    "    return float(((bbox_labels - bbox_preds) * bbox_masks).abs().sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "259d2abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import ops\n",
    "def box_iou(boxes1, boxes2):\n",
    "    \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n",
    "    box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n",
    "                              (boxes[:, 3] - boxes[:, 1]))\n",
    "    # boxes1,boxes2,areas1,areas2的形状:\n",
    "    # boxes1：(boxes1的数量,4),\n",
    "    # boxes2：(boxes2的数量,4),\n",
    "    # areas1：(boxes1的数量,),\n",
    "    # areas2：(boxes2的数量,)\n",
    "    areas1 = box_area(boxes1)\n",
    "    areas2 = box_area(boxes2)\n",
    "    # inter_upperlefts,inter_lowerrights,inters的形状:\n",
    "    # (boxes1的数量,boxes2的数量,2)\n",
    "    inter_upperlefts = ops.maximum(boxes1[:, None, :2], boxes2[:, :2])\n",
    "    inter_lowerrights = ops.minimum(boxes1[:, None, 2:], boxes2[:, 2:])\n",
    "    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n",
    "    # inter_areasandunion_areas的形状:(boxes1的数量,boxes2的数量)\n",
    "    inter_areas = inters[:, :, 0] * inters[:, :, 1]\n",
    "    union_areas = areas1[:, None] + areas2 - inter_areas\n",
    "    return inter_areas / union_areas\n",
    "\n",
    "def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5):\n",
    "    \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n",
    "    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n",
    "    # 位于第i行和第j列的元素x_ij是锚框i和真实边界框j的IoU\n",
    "    jaccard = box_iou(anchors, ground_truth)\n",
    "    # 对于每个锚框，分配的真实边界框的张量\n",
    "    anchors_bbox_map = ops.full((num_anchors,), -1, dtype=mindspore.int64)\n",
    "    # 根据阈值，决定是否分配真实边界框\n",
    "    indices, max_ious = ops.max(jaccard, axis=1)\n",
    "    anc_i = ops.nonzero(max_ious >= iou_threshold).reshape(-1)\n",
    "    box_j = ops.masked_select(indices, max_ious >= iou_threshold)\n",
    "    anchors_bbox_map[anc_i] = box_j \n",
    "    col_discard = ops.full((num_anchors,), -1)\n",
    "    row_discard = ops.full((num_gt_boxes,), -1)\n",
    "    for _ in range(num_gt_boxes):\n",
    "        max_idx = ops.argmax(jaccard)\n",
    "        box_idx = (max_idx % num_gt_boxes).long()\n",
    "        anc_idx = (max_idx / num_gt_boxes).long()\n",
    "        anchors_bbox_map[anc_idx] = box_idx\n",
    "        jaccard[:, box_idx] = col_discard\n",
    "        jaccard[anc_idx, :] = row_discard\n",
    "    return anchors_bbox_map\n",
    "\n",
    "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n",
    "    \"\"\"对锚框偏移量的转换\"\"\"\n",
    "    c_anc = d2l.box_corner_to_center(anchors)\n",
    "    c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n",
    "    offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n",
    "    offset_wh = 5 * ops.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n",
    "    offset = ops.concat([offset_xy, offset_wh], axis=1)\n",
    "    return offset\n",
    "\n",
    "def multibox_target(anchors, labels):\n",
    "    \"\"\"使用真实边界框标记锚框\"\"\"\n",
    "    batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n",
    "    batch_offset, batch_mask, batch_class_labels = [], [], []\n",
    "    num_anchors = anchors.shape[0]\n",
    "    for i in range(batch_size):\n",
    "        label = labels[i, :, :]\n",
    "        anchors_bbox_map = assign_anchor_to_bbox(\n",
    "            label[:, 1:], anchors)\n",
    "        bbox_mask = ops.tile((anchors_bbox_map >= 0).float().unsqueeze(-1), (1, 4))\n",
    "        # 将类标签和分配的边界框坐标初始化为零\n",
    "        class_labels = ops.zeros(num_anchors, dtype=mindspore.int32)\n",
    "        assigned_bb = ops.zeros((num_anchors, 4), dtype=mindspore.float32)\n",
    "        # 使用真实边界框来标记锚框的类别。\n",
    "        # 如果一个锚框没有被分配，标记其为背景（值为零）\n",
    "        indices_true = ops.nonzero(anchors_bbox_map >= 0)\n",
    "        bb_idx = anchors_bbox_map[indices_true]\n",
    "        class_labels[indices_true] = label[bb_idx, 0].long() + 1 \n",
    "        assigned_bb[indices_true] = label[bb_idx, 1:]\n",
    "        # 偏移量转换\n",
    "        offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n",
    "        batch_offset.append(offset.reshape(-1))\n",
    "        batch_mask.append(bbox_mask.reshape(-1))\n",
    "        batch_class_labels.append(class_labels)\n",
    "    bbox_offset = ops.stack(batch_offset)\n",
    "    bbox_mask = ops.stack(batch_mask)\n",
    "    class_labels = ops.stack(batch_class_labels)\n",
    "    return (bbox_offset, bbox_mask, class_labels)\n",
    "\n",
    "# print(anchors.shape, Y.shape) # (1, 5444, 4) (32, 1, 5)\n",
    "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d80ea58d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(32, 3, 256, 256) (32, 1, 5)\n"
     ]
    }
   ],
   "source": [
    "#forward\n",
    "for X, Y in train_iter:\n",
    "    print(X.shape, Y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ca1c0d14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output anchors: (1, 5444, 4)\n",
      "output class preds: (32, 5444, 2)\n",
      "output bbox preds: (32, 21776)\n"
     ]
    }
   ],
   "source": [
    "anchors, cls_preds, bbox_preds = net(X)\n",
    "print('output anchors:', anchors.shape)\n",
    "print('output class preds:', cls_preds.shape)\n",
    "print('output bbox preds:', bbox_preds.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d1f7601f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.896 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n",
      "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.965 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bbox_labels: (32, 21776) (32, 21776)\n",
      "bbox_masks: (32, 21776) Float32\n",
      "cls_labels: (32, 5444) Int32\n"
     ]
    }
   ],
   "source": [
    "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n",
    "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n",
    "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n",
    "print('cls_labels:', cls_labels.shape, cls_labels.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "993302d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[], dtype=Float32, value= 0.69856)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks).mean()\n",
    "l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c78354a",
   "metadata": {},
   "source": [
    "# BUG 直接运行的话，这里会卡住"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0d0317",
   "metadata": {},
   "outputs": [],
   "source": [
    "#forward\n",
    "for X, Y in train_iter:\n",
    "    print(X.shape, Y.shape)\n",
    "    anchors, cls_preds, bbox_preds = net(X)\n",
    "    print('output anchors:', anchors.shape)\n",
    "    print('output class preds:', cls_preds.shape)\n",
    "    print('output bbox preds:', bbox_preds.shape)\n",
    "    \n",
    "    bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n",
    "    print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n",
    "    print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n",
    "    print('cls_labels:', cls_labels.shape, cls_labels.dtype)    \n",
    "    \n",
    "    l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n",
    "                  bbox_masks)\n",
    "    print(l)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ef20694",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, Y = next(train_iter)\n",
    "print(X.shape, Y.shape)\n",
    "anchors, cls_preds, bbox_preds = net(X)\n",
    "print('output anchors:', anchors.shape)\n",
    "print('output class preds:', cls_preds.shape)\n",
    "print('output bbox preds:', bbox_preds.shape)\n",
    "\n",
    "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n",
    "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n",
    "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n",
    "print('cls_labels:', cls_labels.shape, cls_labels.dtype)    \n",
    "\n",
    "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n",
    "              bbox_masks)\n",
    "print(l)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6912bb2d",
   "metadata": {},
   "source": [
    "#### 13.7.2.3. 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3870f282",
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_epochs, timer = 20, d2l.Timer()\n",
    "# animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
    "#                         legend=['class error', 'bbox mae'])\n",
    "\n",
    "# def forward_fn(X, Y):\n",
    "#     # 生成多尺度的锚框，为每个锚框预测类别和偏移量\n",
    "#     anchors, cls_preds, bbox_preds = net(X)\n",
    "#     # 为每个锚框标注类别和偏移量\n",
    "#     bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y) # d2l.\n",
    "#     # 根据类别和偏移量的预测和标注值计算损失函数\n",
    "#     l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n",
    "#                   bbox_masks).mean()\n",
    "#     return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n",
    "    \n",
    "# grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n",
    "    \n",
    "# def train_step(inputs, targets):\n",
    "#     (l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels), grads = grad_fn(inputs, targets)\n",
    "#     trainer(grads)\n",
    "#     return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n",
    "    \n",
    "# for epoch in range(num_epochs):\n",
    "#     # 训练精确度的和，训练精确度的和中的示例数\n",
    "#     # 绝对误差的和，绝对误差的和中的示例数\n",
    "#     metric = d2l.Accumulator(4)\n",
    "#     net.set_train()\n",
    "#     for features, target in train_iter:\n",
    "#         timer.start()\n",
    "#         print(epoch)\n",
    "#         l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels = train_step(features, target)\n",
    "#         metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n",
    "#                    bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n",
    "#                    bbox_labels.numel())\n",
    "#     cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n",
    "#     animator.add(epoch + 1, (cls_err, bbox_mae))\n",
    "# print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n",
    "# print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on '\n",
    "#       f'{str(device)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
